import random
import torch
from sklearn.model_selection import train_test_split
# from .hogrl_mode_dual import *
from .hogrl_utils import *
# from .hogrl_utils import *
import numpy as np
import random as rd
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, average_precision_score
from torch_geometric.utils import degree, to_undirected, dropout_adj
import numpy as np
from scipy.sparse import csr_matrix
import networkx as nx
import logging
import os
from datetime import datetime
# from .kmeans import kmeans
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.decomposition import PCA
import pickle  
import time
import gudhi as gd  # 导入GUDHI库用于持久同调计算
import ot  # 导入POT库用于Wasserstein距离计算
from sklearn.manifold import TSNE
import torch.nn as nn
import torch.nn.functional as F
from scipy import stats

class GradientAwareFocalLoss(nn.Module):
    """Focal loss re-weighted by class balance and per-sample gradient magnitude.

    Each sample's cross-entropy ``-log(p_t)`` is multiplied by three factors:

    * a focal factor ``(1 - p_t) ** gamma_focal``;
    * a class weight ``count ** -(1 - gamma_ga)``, where ``count`` is an
      exponential moving average (momentum 0.9) of how often each class occurs
      among the top ``k_percent`` % highest-loss samples, normalised so the
      weights sum to ``C``;
    * a gradient weight ``||d(-log p_t)/d inputs||_2 ** gamma_grad``.

    Class and gradient weights are multiplied and mean-normalised first, then
    multiplied by the focal factor and mean-normalised again.
    """

    def __init__(self, num_classes, k_percent=10, gamma_focal=2.0, gamma_ga=0.5, gamma_grad=1.0, use_softmax=True):
        """
        Args:
            num_classes: number of classes ``C``.
            k_percent: percentage of highest-loss samples used to update the
                running class counts.
            gamma_focal: focal-loss exponent.
            gamma_ga: class weights are ``count ** -(1 - gamma_ga)``.
            gamma_grad: exponent applied to the per-sample gradient norm.
            use_softmax: if True, ``inputs`` are logits and softmax is applied
                over dim 1; otherwise ``inputs`` are used directly as
                probabilities.
        """
        super(GradientAwareFocalLoss, self).__init__()
        self.num_classes = num_classes
        self.k_percent = k_percent
        self.gamma_focal = gamma_focal
        self.gamma_ga = gamma_ga
        self.gamma_grad = gamma_grad  # strength of the gradient weight
        self.use_softmax = use_softmax
        # Running (EMA) count of each class among the hardest samples, and the
        # class weights derived from it.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_weights', torch.ones(num_classes))

    @staticmethod
    def _class_last(t):
        """Reshape a ``(B, C, *)`` tensor to ``(B * prod(*), C)``."""
        return t.permute(0, *range(2, t.dim()), 1).reshape(-1, t.shape[1])

    def forward(self, inputs, targets):
        """Compute the weighted loss.

        Args:
            inputs: ``(B, C, *)`` logits (probabilities if ``use_softmax`` is
                False).
            targets: integer class indices with ``B * prod(*)`` elements.

        Returns:
            Scalar loss tensor.

        Side effects:
            Updates the ``class_counts`` and ``class_weights`` buffers.
        """
        B, C = inputs.shape[:2]
        N = inputs.shape[2:].numel() * B  # total number of samples

        # 1. Probabilities and per-sample cross-entropy.
        probs = F.softmax(inputs, dim=1) if self.use_softmax else inputs
        probs = self._class_last(probs)
        targets = targets.view(-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -torch.log(pt + 1e-8)

        # 2. Per-sample gradient of that same cross-entropy w.r.t. the inputs.
        #    It runs on a detached copy so it never touches the main graph.
        #    enable_grad() keeps it working when the loss is evaluated under
        #    torch.no_grad(). (Passing softmax output to F.cross_entropy would
        #    apply a second softmax and distort the gradient.)
        with torch.enable_grad():
            inputs_grad = inputs.detach().requires_grad_(True)
            probs_grad = F.softmax(inputs_grad, dim=1) if self.use_softmax else inputs_grad
            pt_grad = self._class_last(probs_grad).gather(1, targets.unsqueeze(1)).squeeze(1)
            loss_grad = -torch.log(pt_grad + 1e-8)
            # Each sample's loss depends only on its own inputs, so the gradient
            # of the sum yields every per-sample gradient at once.
            gradients = torch.autograd.grad(loss_grad.sum(), inputs_grad)[0]

        # 3. Gradient magnitude (L2 norm over the class dimension).
        grad_magnitude = self._class_last(gradients).norm(p=2, dim=1)  # (N,)
        grad_weight = (grad_magnitude + 1e-8) ** self.gamma_grad  # avoid zero weights

        # 4. Dynamic class balancing from the classes of the hardest samples.
        num_topk = max(1, int(self.k_percent / 100 * N))
        _, topk_indices = torch.topk(ce_loss, num_topk, sorted=False)
        topk_targets = targets[topk_indices]
        current_counts = torch.bincount(topk_targets, minlength=self.num_classes).float()
        self.class_counts = 0.9 * self.class_counts + 0.1 * current_counts
        effective_counts = self.class_counts + 1e-8
        self.class_weights = (1.0 / effective_counts) ** (1.0 - self.gamma_ga)
        self.class_weights = self.class_weights / self.class_weights.sum() * C

        # 5. Combine focal, class and gradient weights.
        focal_weight = (1 - pt) ** self.gamma_focal
        class_weight = self.class_weights[targets]

        # Class-aware difficulty, mean-normalised.
        difficulty_weight = class_weight * grad_weight
        difficulty_weight = difficulty_weight / difficulty_weight.mean().clamp_min(1e-12)

        # Sample-level hardness (focal). The clamp matters: when every sample is
        # perfectly classified, focal_weight is all zeros and an unclamped mean
        # would give 0/0 = NaN.
        final_weight = focal_weight * difficulty_weight
        final_weight = final_weight / final_weight.mean().clamp_min(1e-12)

        # 6. Final loss.
        loss = (final_weight * ce_loss).mean()
        return loss
    

# 添加LPL相关的类
def get_step(split: int, classes_num: int, pgd_nums: int, classes_freq: list):
    """Return the perturbation step count for every class.

    Classes with index below ``split`` are compared to ``classes_freq[0]``;
    the rest are compared to ``classes_freq[-1]``. The extra steps for a class
    are ``round(freq / reference * 0.1 * pgd_nums - 1)``, floored at 0 and
    added on top of ``pgd_nums``.
    """
    base = pgd_nums * 0.1
    steps = []
    for idx in range(classes_num):
        reference = classes_freq[0] if idx < split else classes_freq[-1]
        extra = round(classes_freq[idx] / reference * base - 1)
        steps.append(pgd_nums + max(extra, 0))
    return steps

class LPLLoss_advanced(nn.Module):
    """Adaptive logit-perturbation (LPL-style) loss for class-imbalanced data.

    The logits are pushed iteratively along class-dependent directions
    (:meth:`compute_adv_sign`). Each sample gets its own number of steps and
    per-step strength (:meth:`compute_adaptive_params`). Both are larger for
    the running-average minority class and for samples whose cross-entropy
    gradient is large. The loss is the cross-entropy of the perturbed logits.
    """

    def __init__(self, num_classes=2, pgd_nums=50, alpha=0.1, min_class_factor=3.0):
        """
        Args:
            num_classes: number of classes.
            pgd_nums: base number of perturbation steps.
            alpha: base perturbation strength per step.
            min_class_factor: the minority-class strength is raised to at least
                this multiple of the majority-class strength.
        """
        super().__init__()
        self.num_classes = num_classes
        self.pgd_nums = pgd_nums
        self.alpha = alpha
        self.min_class_factor = min_class_factor
        self.criterion = nn.CrossEntropyLoss()

        # Running (EMA) class counts and per-class mean cross-entropy.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_grad_mags', torch.zeros(num_classes))
        self.momentum = 0.9  # EMA momentum for both buffers

    def update_statistics(self, logit, y):
        """Update the EMA class counts and per-class mean cross-entropy.

        The mean cross-entropy of each class in the batch serves as a cheap
        proxy for that class's gradient magnitude. Classes absent from the
        batch contribute 0 to this step's update.
        """
        with torch.no_grad():
            batch_counts = torch.bincount(y, minlength=self.num_classes).float()
            self.class_counts = self.momentum * self.class_counts + (1 - self.momentum) * batch_counts

            grad_mags = torch.zeros(self.num_classes, device=logit.device)
            for c in range(self.num_classes):
                class_mask = (y == c)
                if torch.sum(class_mask) > 0:
                    ce_loss = F.cross_entropy(logit[class_mask], y[class_mask], reduction='none')
                    grad_mags[c] = ce_loss.mean().item()

            self.class_grad_mags = self.momentum * self.class_grad_mags + (1 - self.momentum) * grad_mags

    def compute_adaptive_params(self, logit, y):
        """Compute per-sample perturbation step counts and strengths.

        Also updates the running statistics via :meth:`update_statistics`.

        Per class ``c`` with running class ratio ``r_c``:

        * steps: ``min_steps + int((max_steps - min_steps) * f / (f + 1))``
          with ``f = sqrt(1 / r_c)``, ``max_steps = int(2 * pgd_nums)`` and
          ``min_steps = max(1, int(0.5 * pgd_nums))``;
        * strength: ``alpha * (1 + 2 * softmax(class_grad_mags)[c])``. For the
          minority class this is further multiplied by
          ``min(5, sqrt(clamp(imbalance_ratio, 1, 10)))``.

        The minority class then gets at least 1.5x the majority step count
        and ``min_class_factor`` x the majority strength. Finally, each
        sample's values are scaled by its softmax-normalised cross-entropy
        gradient norm: strengths by a factor in (0.8, 1.5], steps by a factor
        in (1.0, 1.5] (truncated to int).

        Returns:
            (sample_steps, sample_alphas): long / float tensors shaped like ``y``.
        """
        with torch.no_grad():
            self.update_statistics(logit, y)

            total_samples = torch.sum(self.class_counts)
            class_ratios = self.class_counts / (total_samples + 1e-8)

            minority_idx = torch.argmin(class_ratios).item()
            if self.num_classes == 2:
                majority_idx = 1 - minority_idx
            else:
                majority_idx = torch.argmax(class_ratios).item()

            imbalance_ratio = class_ratios[majority_idx] / (class_ratios[minority_idx] + 1e-8)
            imbalance_factor = torch.clamp(imbalance_ratio, 1.0, 10.0).item()

            # Classes with larger mean cross-entropy get a stronger perturbation.
            grad_scale = F.softmax(self.class_grad_mags, dim=0)

            class_steps = torch.zeros(self.num_classes, device=logit.device, dtype=torch.long)
            class_alphas = torch.zeros(self.num_classes, device=logit.device, dtype=torch.float)

            max_steps = int(self.pgd_nums * 2.0)
            min_steps = max(1, int(self.pgd_nums * 0.5))

            for c in range(self.num_classes):
                # Rarer classes get more steps.
                freq_factor = torch.sqrt(1.0 / (class_ratios[c] + 1e-8))
                class_steps[c] = min_steps + int((max_steps - min_steps) * freq_factor / (freq_factor + 1.0))

                alpha = self.alpha * (1.0 + grad_scale[c].item() * 2.0)
                if c == minority_idx:
                    alpha = alpha * min(5.0, imbalance_factor ** 0.5)
                class_alphas[c] = alpha

            # Minority class: at least 1.5x the majority step count ...
            if class_steps[minority_idx] < class_steps[majority_idx] * 1.5:
                class_steps[minority_idx] = int(class_steps[majority_idx] * 1.5)
            # ... and at least min_class_factor x the majority strength.
            if class_alphas[minority_idx] < class_alphas[majority_idx] * self.min_class_factor:
                class_alphas[minority_idx] = class_alphas[majority_idx] * self.min_class_factor

            # Broadcast the class-level values to samples.
            sample_steps = torch.zeros_like(y, dtype=torch.long)
            sample_alphas = torch.zeros_like(y, dtype=torch.float)
            for c in range(self.num_classes):
                class_mask = (y == c)
                sample_steps[class_mask] = class_steps[c]
                sample_alphas[class_mask] = class_alphas[c]

            # Sample-level difficulty: norm of the cross-entropy gradient
            # w.r.t. a detached copy of the logits.
            with torch.enable_grad():
                logit_grad = logit.detach().clone().requires_grad_(True)
                loss = F.cross_entropy(logit_grad, y, reduction='none')
                grads = torch.autograd.grad(
                    outputs=loss.sum(),
                    inputs=logit_grad,
                    create_graph=False,
                    retain_graph=False
                )[0]

            sample_grad_norms = torch.norm(grads, p=2, dim=1)
            sample_difficulties = F.softmax(sample_grad_norms, dim=0)
            relative_difficulty = sample_difficulties / (torch.max(sample_difficulties) + 1e-8)

            # Strength scaled into (0.8, 1.5], steps into (1.0, 1.5].
            sample_alphas = sample_alphas * (0.8 + 0.7 * relative_difficulty)
            sample_steps = (sample_steps.float() * (1.0 + 0.5 * relative_difficulty)).long()

            return sample_steps, sample_alphas

    def compute_adv_sign(self, logit, y, sample_alphas):
        """Compute one perturbation step for every sample.

        Args:
            logit: ``(N, C)`` current logits.
            y: ``(N,)`` class indices.
            sample_alphas: ``(N,)`` per-sample step strengths.

        Returns:
            adv_logit: ``(N, C)`` additive perturbation for this step.
            sub: ``(C,)`` threshold minus each class's mean own-class
                probability; its sign decides the push direction.
        """
        with torch.no_grad():
            logit_softmax = F.softmax(logit, dim=-1)
            y_onehot = F.one_hot(y, num_classes=self.num_classes)

            # Row c = sum over class-c samples of their softmax vectors.
            sum_class_logit = torch.matmul(
                y_onehot.permute(1, 0)*1.0, logit_softmax)
            sum_class_num = torch.sum(y_onehot, dim=0)

            # Record which classes are present *before* replacing zero counts.
            # After the replacement every count is positive, so a mask taken
            # then would always be True and absent classes (whose mean
            # probability is 0) would wrongly lower the threshold.
            mean_mask = sum_class_num > 0

            # Avoid division by zero for absent classes (their rows are zero).
            sum_class_num = torch.where(sum_class_num == 0, 100, sum_class_num)
            mean_class_logit = torch.div(sum_class_logit, sum_class_num.reshape(-1, 1))

            grad = mean_class_logit - torch.eye(self.num_classes, device=logit.device)
            # NOTE(review): row c is divided by the L2 norm of *column* c
            # (dim=0); confirm whether row-wise (dim=1) normalisation was meant.
            grad = torch.div(grad, torch.norm(grad, p=2, dim=0).reshape(-1, 1) + 1e-8)

            # Threshold = mean own-class probability over the classes present.
            mean_class_p = torch.diag(mean_class_logit)
            mean_class_thr = torch.mean(mean_class_p[mean_mask])
            sub = mean_class_thr - mean_class_p
            sign = sub.sign()

            alphas_expanded = sample_alphas.unsqueeze(1).expand(-1, self.num_classes)
            adv_logit = torch.index_select(grad, 0, y) * alphas_expanded * sign[y].unsqueeze(1)

            return adv_logit, sub

    def compute_eta(self, logit, y):
        """Compute the total perturbation for each sample.

        All samples are stepped ``max(sample_steps)`` times, and every
        intermediate result is kept. Sample ``i`` then takes the result after
        ``sample_steps[i]`` steps.

        Returns:
            (eta, sample_steps, sample_alphas) with ``eta`` shaped like ``logit``.
        """
        with torch.no_grad():
            sample_steps, sample_alphas = self.compute_adaptive_params(logit, y)
            max_steps = torch.max(sample_steps).item()

            logit_steps = torch.zeros(
                [max_steps + 1, logit.shape[0], self.num_classes], device=logit.device)
            current_logit = logit.clone()
            logit_steps[0] = current_logit
            for i in range(1, max_steps + 1):
                adv_logit, _ = self.compute_adv_sign(current_logit, y, sample_alphas)
                current_logit = current_logit + adv_logit
                logit_steps[i] = current_logit

            # Vectorised pick: sample i takes the logits after sample_steps[i] steps.
            rows = torch.arange(logit.shape[0], device=logit.device)
            logit_news = logit_steps[sample_steps, rows].to(logit.dtype)

            eta = logit_news - logit
            return eta, sample_steps, sample_alphas

    def forward(self, models_or_logits, x=None, y=None, is_logits=False):
        """Return the cross-entropy of the adaptively perturbed logits.

        Args:
            models_or_logits: logits (if ``is_logits``) or a callable model
                applied to ``x``.
            x: model input, used only when ``is_logits`` is False.
            y: class indices.
            is_logits: whether ``models_or_logits`` already holds logits.

        Returns:
            (loss_adv, logit, logit_news, sample_steps, sample_alphas). The
            perturbation is computed without grad, so gradients reach only
            ``logit``.
        """
        if is_logits:
            logit = models_or_logits
        else:
            logit = models_or_logits(x)

        eta, sample_steps, sample_alphas = self.compute_eta(logit, y)
        logit_news = logit + eta
        loss_adv = self.criterion(logit_news, y)

        return loss_adv, logit, logit_news, sample_steps, sample_alphas

def visualize_clustering(embeddings, pseudo_labels, high_confidence_idx, epoch, save_path, overwrite_previous=True, save_dir=None):
    """Save a 3-D PCA scatter plot of the high-confidence clustering result.

    Only the nodes in ``high_confidence_idx`` are drawn. If it is None, the
    figure has axes and a legend but no points.

    Args:
        embeddings: node embedding tensor, one row per node.
        pseudo_labels: cluster-label tensor aligned with ``embeddings``.
        high_confidence_idx: index tensor of high-confidence nodes, or None.
        epoch: current epoch (shown in the title and used in the file name).
        save_path: file-name prefix of the saved image.
        overwrite_previous: if True, write ``{save_path}_latest.png`` (replaced
            on each call); otherwise write ``{save_path}_epoch{epoch}.png``.
        save_dir: output directory. None keeps the original default location.
    """
    if save_dir is None:
        save_dir = '/root/autodl-tmp/hali/antifraud/log_zp4/fig'
    os.makedirs(save_dir, exist_ok=True)

    # Project to 3-D with PCA. detach() allows embeddings that are still
    # attached to the autograd graph (.numpy() would otherwise raise).
    pca = PCA(n_components=3)
    embeddings_3d = pca.fit_transform(embeddings.detach().cpu().numpy())

    pseudo_labels_np = pseudo_labels.detach().cpu().numpy()
    # Convert the index tensor once instead of on every use.
    idx_np = None if high_confidence_idx is None else high_confidence_idx.cpu().numpy()

    # Relabel so that the more frequent high-confidence class becomes label 0.
    if idx_np is not None:
        label_counts = np.bincount(pseudo_labels_np[idx_np])
        if len(label_counts) == 2 and label_counts[1] > label_counts[0]:
            pseudo_labels_np = 1 - pseudo_labels_np

    plt.figure(figsize=(10, 8))
    ax = plt.axes(projection='3d')

    # Label 0 -> orange, label 1 -> blue.
    colors = ['#FF7F0E', '#1F77B4']
    custom_cmap = ListedColormap(colors)

    ax.set_facecolor('white')
    ax.grid(False)

    if idx_np is not None:
        high_conf_3d = embeddings_3d[idx_np]
        high_conf_labels = pseudo_labels_np[idx_np]

        label_0 = np.sum(high_conf_labels == 0)
        label_1 = np.sum(high_conf_labels == 1)
        total_samples = len(high_conf_labels)

        # vmin/vmax pin label 0 to orange and label 1 to blue regardless of
        # which labels are present, so the colours always match the legend.
        ax.scatter(high_conf_3d[:, 0], high_conf_3d[:, 1], high_conf_3d[:, 2],
                   c=high_conf_labels,
                   cmap=custom_cmap,
                   vmin=0,
                   vmax=1,
                   s=15,
                   alpha=0.6,
                   marker='o')

        ax.set_title(f'(Epoch {epoch})\n'
                     f'High confidence samples: {total_samples}\n'
                     f'Negative examples: {label_0} | Positive examples: {label_1}')

    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#FF7F0E',
               markersize=8, label='Negative examples'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#1F77B4',
               markersize=8, label='Positive examples')
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    ax.view_init(elev=20, azim=45)
    plt.tight_layout()

    if overwrite_previous:
        save_file = os.path.join(save_dir, f'{save_path}_latest.png')
    else:
        save_file = os.path.join(save_dir, f'{save_path}_epoch{epoch}.png')

    plt.savefig(save_file, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close()

    print(f"聚类可视化已保存到: {save_file}")

def test(idx_eval, y_eval, gnn_model, feat_data, edge_indexs):
    """Evaluate ``gnn_model`` on the nodes ``idx_eval``.

    The model is switched to eval mode (and left there). Its first output is
    exponentiated to obtain class probabilities. This presumes the model
    returns log-probabilities (e.g. log_softmax) — TODO confirm. Class 1 is
    predicted when its probability is >= 0.5.

    Args:
        idx_eval: indices of the nodes to evaluate.
        y_eval: ground-truth labels (0/1) of those nodes.
        gnn_model: model called as ``gnn_model(feat_data, edge_indexs)`` and
            returning a pair whose first element is the per-node output.
        feat_data: node features.
        edge_indexs: graph structure.

    Returns:
        (auc, average_precision, macro_f1, g_mean, acc_label0, acc_label1,
        acc_overall). ``acc_label0`` / ``acc_label1`` are per-class recalls
        (0.0 when the class is absent).
    """
    gnn_model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits, _ = gnn_model(feat_data, edge_indexs)
    positive_probs_all = torch.exp(logits[:, 1]).detach().cpu().numpy()
    positive_class_probs = positive_probs_all[np.array(idx_eval)]

    y_eval_np = np.array(y_eval)
    auc_score = roc_auc_score(y_eval_np, positive_class_probs)

    label_prob = (positive_class_probs >= 0.5).astype(int)
    acc_overall = accuracy_score(y_eval_np, label_prob)

    # Per-class accuracy = fraction of that class predicted correctly.
    is_neg = y_eval_np == 0
    is_pos = y_eval_np == 1
    num_neg = np.sum(is_neg)
    num_pos = np.sum(is_pos)
    acc_label0 = np.sum(is_neg & (label_prob == 0)) / num_neg if num_neg > 0 else 0.0
    acc_label1 = np.sum(is_pos & (label_prob == 1)) / num_pos if num_pos > 0 else 0.0

    ap_score = average_precision_score(y_eval_np, positive_class_probs)
    f1_score_val = f1_score(y_eval_np, label_prob, average='macro')
    g_mean = calculate_g_mean(y_eval_np, label_prob)

    return auc_score, ap_score, f1_score_val, g_mean, acc_label0, acc_label1, acc_overall

def sigmoid_rampup(current, rampup_length):
    """Return the ramp-up weight ``exp(-5 * (1 - t)^2)`` for progress ``t``.

    ``t = clip(current, 0, rampup_length) / rampup_length``, following
    https://arxiv.org/abs/1610.02242. A zero ramp-up length means the ramp is
    already complete (weight 1.0).
    """
    if rampup_length != 0:
        clipped = np.clip(current, 0.0, rampup_length)
        remaining = 1.0 - clipped / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))
    return 1.0

def get_current_mu(epoch, args):
    """Return the weight ``mu`` for ``epoch``, optionally ramped up.

    When ``args['mu_rampup']`` is truthy, ``args['mu']`` is scaled by
    ``sigmoid_rampup(epoch, args['consistency_rampup'])``. A missing (None)
    ``consistency_rampup`` is replaced by 500, and this is written back into
    ``args``. Otherwise ``args['mu']`` is returned unchanged.
    """
    if not args['mu_rampup']:
        return args['mu']
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    if args['consistency_rampup'] is None:
        args['consistency_rampup'] = 500
    return args['mu'] * sigmoid_rampup(epoch, args['consistency_rampup'])

def initialize_centroids(features, k):
    """Pick ``k`` initial cluster centres from ``features`` (k-means++ seeding).

    The first centre is a uniformly random row. Each further centre is a row
    sampled with probability proportional to its *squared* distance to the
    nearest centre chosen so far (the D^2 weighting of k-means++). If every
    such distance is zero (e.g. all rows identical), the next centre is drawn
    uniformly instead of producing an invalid 0/0 distribution.

    Args:
        features: ``(num_nodes, dim)`` tensor.
        k: number of centres.

    Returns:
        ``(k, dim)`` tensor of centres on ``features.device``.
    """
    num_nodes = features.size(0)
    centroids = torch.zeros(k, features.size(1), device=features.device)

    first_id = torch.randint(num_nodes, (1,)).item()
    centroids[0] = features[first_id]

    for i in range(1, k):
        nearest = torch.cdist(features, centroids[:i]).min(dim=1).values
        weights = nearest ** 2
        total = weights.sum()
        if not torch.isfinite(total) or total <= 0:
            weights = torch.ones_like(weights)
        # multinomial accepts unnormalised non-negative weights.
        next_id = torch.multinomial(weights, 1).item()
        centroids[i] = features[next_id]

    return centroids

def check_convergence(centroids, prev_centroids, tol=1e-4):
    """Return whether the centres moved less than ``tol`` since the last step.

    The movement is the L2 norm of the element-wise difference over all
    entries. The result is a 0-d boolean tensor.
    """
    movement = (centroids - prev_centroids).norm()
    return movement < tol

# 新增：鲁棒节点聚类方法
def robust_node_clustering(features, k=2, temperature=0.1, max_iterations=10, labeled_features=None, labeled_classes=None):
    """Soft-cluster node features into ``k`` clusters.

    Centre initialisation:

    * with both ``labeled_features`` and ``labeled_classes``: centre ``i`` is
      the mean of the labelled rows of class ``i`` (a random unit vector if
      that class has none). Every centre is then scaled to unit L2 norm, and
      no refinement iterations run;
    * otherwise: seeding via ``initialize_centroids``, then up to
      ``max_iterations`` soft k-means updates using Gumbel-softmax
      assignments, stopping early once the centres move less than 1e-4.

    The centres are computed under ``torch.no_grad()``. The final assignment
    is recomputed outside it, so gradients flow from it into ``features``.

    Args:
        features: ``(num_nodes, feature_dim)`` node features.
        k: number of clusters (default 2, binary classification).
        temperature: divides the negative distances and is also used as the
            Gumbel-softmax ``tau``.
        max_iterations: maximum refinement iterations (unlabelled mode only).
        labeled_features: ``(num_labeled, feature_dim)`` labelled features.
        labeled_classes: ``(num_labeled,)`` labels of those features.

    Returns:
        tuple ``(original, view1, view2, centroids)``:
            original: ``(num_nodes, k)`` soft Gumbel-softmax assignment.
            view1, view2: currently the very same tensor object as
                ``original`` (no augmented views are computed here).
            centroids: ``(k, feature_dim)`` cluster centres.
    """
    feature_dim = features.size(1)
    device = features.device
    use_labels = labeled_features is not None and labeled_classes is not None

    # Centre estimation needs no gradients.
    with torch.no_grad():
        if use_labels:
            centroids = torch.zeros(k, feature_dim, device=device)
            for i in range(k):
                class_indices = torch.where(labeled_classes == i)[0]
                if len(class_indices) > 0:
                    centroids[i] = labeled_features[class_indices].mean(dim=0)
                else:
                    # No labelled sample of this class: random unit vector.
                    centroids[i] = F.normalize(torch.randn(feature_dim, device=device), p=2, dim=0)
            # Give all centres the same (unit) norm; epsilon avoids 0/0.
            centroids = centroids / (torch.norm(centroids, dim=1, keepdim=True) + 1e-10)
        else:
            centroids = initialize_centroids(features, k)
            prev_centroids = centroids.clone()
            for _ in range(max_iterations):
                distances = torch.cdist(features, centroids)  # [num_nodes, k]
                logits = -distances / temperature
                cluster_assignments = F.gumbel_softmax(logits, tau=temperature, hard=False)

                # Weighted mean of the features for each centre.
                new_centroids = torch.zeros_like(centroids)
                for j in range(k):
                    weights = cluster_assignments[:, j].unsqueeze(1)  # [num_nodes, 1]
                    if weights.sum() > 0:
                        new_centroids[j] = (features * weights).sum(0) / weights.sum()
                    else:
                        new_centroids[j] = centroids[j].clone()  # keep the old centre
                centroids = new_centroids

                if check_convergence(centroids, prev_centroids, tol=1e-4):
                    break
                prev_centroids = centroids.clone()

    # Final soft assignment, computed with gradients enabled.
    distances_original = torch.cdist(features, centroids)  # [num_nodes, k]
    logits_original = -distances_original / temperature
    original_cluster_assignments = F.gumbel_softmax(logits_original, tau=temperature, hard=False)

    # No augmented views are computed: both "views" alias the original result.
    view1_cluster_assignments = original_cluster_assignments
    view2_cluster_assignments = original_cluster_assignments

    return original_cluster_assignments, view1_cluster_assignments, view2_cluster_assignments, centroids

def compute_clustering_loss(features, cluster_assignments, centroids, epsilon=1e-6):
    """Clustering loss for a two-cluster assignment (0 = negative, 1 = positive).

    Features and centres are L2-normalised, and each node is hard-assigned by
    ``argmax(cluster_assignments)``. With ``d(x, c)`` the Euclidean distance
    and ``D`` the mean pairwise centre distance, the loss is::

        0.5 * (mean_{pos} d(x, c_1) + mean_{neg} d(x, c_0))
      - 0.5 * D
      + 0.1 * mean_x ||x - c_{assign(x)}||^2 / (D + epsilon)

    An empty cluster contributes 0 to the first term.

    Args:
        features: ``(N, dim)`` node features.
        cluster_assignments: ``(N, K)`` (soft) assignments.
        centroids: ``(K, dim)`` cluster centres.
        epsilon: stabiliser for the ratio term.

    Returns:
        (total_loss, num_pos, num_neg) where the counts are Python ints.
    """
    features = F.normalize(features, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)

    with torch.no_grad():
        hard_assignments = torch.argmax(cluster_assignments, dim=1)
        pos_indices = torch.nonzero(hard_assignments == 1).squeeze(-1)
        neg_indices = torch.nonzero(hard_assignments == 0).squeeze(-1)
    num_pos = pos_indices.numel()
    num_neg = neg_indices.numel()

    distances = torch.cdist(features, centroids)  # [N, K]
    intra_positive_loss = torch.mean(distances[pos_indices, 1]) if num_pos > 0 else torch.tensor(0.0, device=features.device)
    intra_negative_loss = torch.mean(distances[neg_indices, 0]) if num_neg > 0 else torch.tensor(0.0, device=features.device)
    intra_loss = intra_positive_loss + intra_negative_loss

    # Push centres apart.
    mean_centroid_dist = torch.mean(torch.pdist(centroids))
    inter_loss = -mean_centroid_dist

    # Compactness relative to centre separation.
    expanded_centroids = torch.index_select(centroids, 0, hard_assignments)
    compactness = torch.mean(torch.sum((features - expanded_centroids) ** 2, dim=1))
    joint_reg = compactness / (mean_centroid_dist + epsilon)

    total_loss = 0.5 * intra_loss + 0.5 * inter_loss + 0.1 * joint_reg
    return total_loss, num_pos, num_neg




def hogrl_main(args):
    # 设置设备
    if torch.cuda.is_available() and args['gpu_id'] >= 0:
        device = torch.device(f"cuda:{args['gpu_id']}")
    else:
        device = torch.device('cpu')
    
    # 添加伪标签平衡控制参数，默认为False
    balance_pseudo_labels = args.get('balance_pseudo_labels', False)
    
    # 添加伪标签算法开关变量
    # 控制是否使用原始图的输出直接计算伪标签
    use_original_pseudo_labels = args.get('use_original_pseudo_labels', True)
    # 控制是否使用聚类结果生成伪标签
    use_clustering_pseudo_labels = args.get('use_clustering_pseudo_labels', True)
    
    # 修改日志文件命名
    #timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file_path = f'/root/autodl-tmp/hali/antifraud/log_zp4/output_v4_2.log'
    
    def log_and_print(message):
        print(message)  # 打印信息到控制台
        # 确保logs目录存在
        os.makedirs('logs', exist_ok=True)
        with open(log_file_path, 'a') as file:
            file.write(message + '\n')
    
    # 添加聚类中心固定的配置信息
    fixed_cluster_epochs = 10  # 前10个epoch固定聚类中心
    log_and_print(f"\n【聚类初始化策略】")
    log_and_print(f"  前{fixed_cluster_epochs}个epoch: 使用有标签样本固定聚类中心")
    log_and_print(f"  第{fixed_cluster_epochs}个epoch之后: 使用自由聚类")
    
    # 记录伪标签策略到日志
    log_and_print(f"\n【伪标签策略配置】")
    log_and_print(f"  使用原始图输出生成伪标签: {'启用' if use_original_pseudo_labels else '禁用'}")
    log_and_print(f"  使用聚类结果生成伪标签: {'启用' if use_clustering_pseudo_labels else '禁用'}")
    if use_original_pseudo_labels and use_clustering_pseudo_labels:
        log_and_print(f"  策略: 两种方法的softmax结果相加除以2，然后决定伪标签")
        log_and_print(f"  损失函数: GradientAwareFocalLoss")
    elif use_original_pseudo_labels:
        log_and_print(f"  策略: 仅使用原始图输出生成伪标签")
        log_and_print(f"  损失函数: 普通交叉熵损失 (CrossEntropyLoss)")
    elif use_clustering_pseudo_labels:
        log_and_print(f"  策略: 仅使用聚类结果生成伪标签")
        log_and_print(f"  损失函数: GradientAwareFocalLoss")
    else:
        log_and_print(f"  警告: 两种伪标签方法均已禁用，将不使用伪标签")
    
    # 初始化GradientAwareFocalLoss损失函数
    gradient_aware_focal = GradientAwareFocalLoss(num_classes=2,
                                                  k_percent=10,
                                                  gamma_focal=2,
                                                  gamma_ga=0.5,
                                                  gamma_grad=1,
                                                  use_softmax=True).to(device)
    # 初始化自适应LPL损失函数
    adaptive_lpl_loss = LPLLoss_advanced(
        num_classes=2,
        pgd_nums=30,
        alpha=0.05,
        min_class_factor=3.5
    ).to(device)
    
    # 记录关键损失函数配置
    log_and_print(f"\n【损失函数配置】")
    # 原始LPL损失配置
    log_and_print(f"  原始LPL损失：")

    # 自适应LPL损失配置
    log_and_print(f"  自适应LPL损失：")
    log_and_print(f"    pgd_nums = {adaptive_lpl_loss.pgd_nums} (基础扰动步数)")
    log_and_print(f"    alpha = {adaptive_lpl_loss.alpha} (基础扰动强度)")
    log_and_print(f"    特点: 自动根据类别分布和样本难度调整扰动参数")
    

    # 添加调试信息控制变量
    debug_print = args.get('debug_print', False)  # 如果args中没有设置，默认为False

    print(f"Using device: {device}")
    
    prefix = os.path.join(os.path.dirname(__file__), "..", "..", "/root/autodl-tmp/antifraud/data/")
    print('loading data...')
    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)
    
    best_model_id = 0
    np.random.seed(args['seed'])
    random.seed(args['seed'])
    
    if args['dataset'] == 'yelp' or args['dataset'] == 'CCFD':
        assert args['dataset'] != 'CCFD', 'Due to confidentiality agreements, we are unable to provide the CCFD data.'
        
        index = list(range(len(labels)))
        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels, stratify=labels, test_size=args['test_size'], random_state=2, shuffle=True)
        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val, stratify=y_train_val, test_size=args['val_size'], random_state=2, shuffle=True)
        lambda_cl = 0.7
        drop_edge_rate_1 = 0.2
        drop_edge_rate_2 = 0.3
        use_pot = False
        
        # 分离正负样本
        train_pos, train_neg = pos_neg_split(idx_train, y_train)
        
        # 随机选取一个正样本和一个负样本
        np.random.shuffle(train_pos)  # 随机打乱正样本列表
        np.random.shuffle(train_neg)  # 随机打乱负样本列表
        one_pos = [train_pos[0]]  # 随机选择一个正样本
        one_neg = [train_neg[0]]  # 随机选择一个负样本
        
        # 将其余所有样本放入无标签池
        unlabeled_pool = []
        for idx in idx_train:
            if idx != one_pos[0] and idx != one_neg[0]:
                unlabeled_pool.append(idx)
        
        # 更新训练集和标签
        idx_train = one_pos + one_neg
        y_train = labels[idx_train]
        
        # 输出训练集中无标签数据和有标签数据数量
        print(f"仅保留一个随机正样本和一个随机负样本，其余全部设为无标签")
        print(f"有标签数据数量: {len(idx_train)}")
        print(f"无标签数据数量: {len(unlabeled_pool)}")
        print(f"其中正样本: {y_train.count(1) if isinstance(y_train, list) else np.sum(y_train == 1)}")
        print(f"其中负样本: {y_train.count(0) if isinstance(y_train, list) else np.sum(y_train == 0)}")
        train_unlabeled = unlabeled_pool  # 设置无标签训练集
        
    elif args['dataset'] == 'amazon':
        # 先按原来方式分割有标签数据
        labeled_index = list(range(3305, len(labels)))
        lambda_cl = 0.7
        drop_edge_rate_1 = 0.2
        drop_edge_rate_2 = 0.3
        use_pot = True
        idx_train_val, idx_test, y_train_val, y_test = train_test_split(
            labeled_index, 
            labels[3305:], 
            stratify=labels[3305:], 
            test_size=args['test_size'], 
            random_state=2, 
            shuffle=True
        )
        idx_train, idx_val, y_train, y_val = train_test_split(
            idx_train_val, 
            y_train_val, 
            stratify=y_train_val, 
            test_size=args['val_size'], 
            random_state=2, 
            shuffle=True
        )
      
        # 先获取原始无标签数据
        unlabeled_pool = list(range(0, 3305))
        original_train = idx_train.copy()  # 保存原始训练集
        
        # 获取正负样本
        train_pos, train_neg = pos_neg_split(idx_train, y_train)
        
        # 随机选取一个正样本和一个负样本
        np.random.shuffle(train_pos)  # 随机打乱正样本列表
        np.random.shuffle(train_neg)  # 随机打乱负样本列表
        one_pos = [train_pos[0]]  # 随机选择一个正样本
        one_neg = [train_neg[0]]  # 随机选择一个负样本
        
        # 将剩余的有标签数据转为无标签
        convert_to_unlabel = list(set(original_train) - set(one_pos + one_neg))
        train_unlabeled = unlabeled_pool + convert_to_unlabel  # 新的无标签训练集
        
        # 更新训练集和标签
        idx_train = one_pos + one_neg
        y_train = labels[idx_train]
        
        # 输出训练集中无标签数据和有标签数据数量
        print(f"仅保留一个随机正样本和一个随机负样本，其余全部设为无标签")
        print(f"有标签数据数量: {len(idx_train)}")
        print(f"无标签数据数量: {len(train_unlabeled)}")
        print(f"其中正样本: {y_train.count(1) if isinstance(y_train, list) else np.sum(y_train == 1)}")
        print(f"其中负样本: {y_train.count(0) if isinstance(y_train, list) else np.sum(y_train == 0)}")
    
    # 这里正负样本可能会变化，所以重新计算
    train_pos, train_neg = pos_neg_split(idx_train, y_train)
    
    def nt_xent_loss(z_i, z_j, temperature=0.01):
        """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

        Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
        other row of either view acts as a negative for that anchor, and each
        row's similarity with itself is excluded.

        The loss is evaluated in log space (``logsumexp``) instead of
        exponentiating and dividing. With the default ``temperature=0.01`` a
        cosine similarity close to 1 yields a logit near 100, and ``exp(100)``
        overflows float32 to ``inf`` (``inf / inf`` -> NaN). The log-space form
        is mathematically identical and stays finite.

        :param z_i: Tensor of shape (n, d), representations of the first augmented view.
        :param z_j: Tensor of shape (n, d), representations of the second augmented view.
        :param temperature: Float, temperature scaling factor for the loss function.
        :return: Scalar tensor, the loss averaged over all 2n anchors (0 when n == 0).
        :raises ValueError: if ``z_i`` and ``z_j`` have different numbers of rows.
        """
        if z_i.size(0) != z_j.size(0):
            raise ValueError(
                f"z_i and z_j must have the same number of rows, "
                f"got {z_i.size(0)} and {z_j.size(0)}"
            )
        n = z_i.size(0)
        if n == 0:
            # Nothing to contrast; a zero keeps the caller's loss sum well-defined.
            return z_i.new_zeros(())

        # After L2 normalisation the dot product equals the cosine similarity,
        # so one matmul replaces a broadcast that would build a (2n, 2n, d) tensor.
        reps = torch.cat([F.normalize(z_i, dim=-1), F.normalize(z_j, dim=-1)], dim=0)
        logits = reps @ reps.t() / temperature

        num_rows = 2 * n
        # Drop self-similarity from the softmax denominator (exp(-inf) == 0).
        self_mask = torch.eye(num_rows, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float("-inf"))

        # The positive for anchor k is its counterpart in the other view: k <-> k + n.
        rows = torch.arange(num_rows, device=logits.device)
        pos_logits = logits[rows, (rows + n) % num_rows]

        # -log(exp(pos) / sum_j exp(logit_j)) == logsumexp(logits) - pos
        return (torch.logsumexp(logits, dim=1) - pos_logits).mean()

    def generate_contrastive_pairs(batch_nodes, labels, feat_data):
        """Sample one same-class and one different-class partner for each batch node.

        For every node in ``batch_nodes`` a positive partner is drawn uniformly
        (``np.random.choice``) from the *other* nodes sharing its label, and a
        negative partner from the nodes whose label differs. A node whose class
        has no other member gets no positive pair; a node with no differently
        labelled node gets no negative pair.

        NOTE(review): partners are drawn from every index of ``labels``. If
        ``labels`` carries ground-truth labels for nodes that are meant to be
        unlabeled or held out, those labels influence the sampled pairs —
        confirm this is intended.

        :param batch_nodes: node indices of the current batch (list, numpy array or tensor).
        :param labels: per-node labels (list, numpy array, or tensor on any device).
        :param feat_data: unused; kept for interface compatibility.
        :return: tuple ``(positive_pairs, negative_pairs)``, each a list of
            ``(node, partner)`` tuples.
        """
        # Bring inputs to host-side numpy. Tensors on *any* device are converted
        # (not only CUDA ones), and plain lists become arrays so that the
        # element-wise comparisons below work instead of silently yielding no pairs.
        if isinstance(labels, torch.Tensor):
            labels_np = labels.detach().cpu().numpy()
        else:
            labels_np = np.asarray(labels)
        if isinstance(batch_nodes, torch.Tensor):
            batch_nodes = batch_nodes.detach().cpu().numpy()

        # Per-class index arrays, computed once per class rather than once per node.
        # np.flatnonzero returns the same ascending indices as np.where(...)[0].
        same_by_class = {}
        diff_by_class = {}

        positive_pairs = []
        negative_pairs = []
        for node in batch_nodes:
            cls = labels_np[node]
            if cls not in same_by_class:
                same_by_class[cls] = np.flatnonzero(labels_np == cls)
                diff_by_class[cls] = np.flatnonzero(labels_np != cls)

            # Positive pair: another node of the same class (excluding the node itself).
            same_others = same_by_class[cls]
            same_others = same_others[same_others != node]
            if same_others.size > 0:
                positive_pairs.append((node, np.random.choice(same_others, 1)[0]))

            # Negative pair: a random node of a different class.
            diff_class_nodes = diff_by_class[cls]
            if diff_class_nodes.size > 0:
                negative_pairs.append((node, np.random.choice(diff_class_nodes, 1)[0]))

        return positive_pairs, negative_pairs


    # 然后将encoder传入模型，并传入共享扰动参数
    gnn_model_1 = multi_HOGRL_Model(               # 传入encoder
        in_feat=feat_data.shape[1], 
        out_feat=2, 
        relation_nums=len(edge_indexs),
        hidden=args['emb_size'], 
        drop_rate=args['drop_rate'],
        weight=args['weight'], 
        num_layers=args['layers'],
        layers_tree=args['layers_tree'],
        temperature=0.5,
        dataset=args['dataset']
    ).to(device)
    

    for edge_index in edge_indexs:
        edge_index[0] = edge_index[0].to(device)
        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]

    # labels = torch.tensor(labels).to(device)
    feat_data = torch.tensor(feat_data).float().to(device)

    # 初始化 - 更关注正样本AUC (移除perturbation_module参数)
    optimizer_1 = torch.optim.Adam(
        list(gnn_model_1.parameters()), 
        lr=0.002, #0.003 # 0.001
        weight_decay=2e-5 #3e-5 #5e-6
    )

    
    batch_size = args['batch_size']

    best_val_auc = 0.0
    best_model_state = None
    best_test_auc = 0.0
    
    print('generating augmented views...')
    # 使用不同的权重生成两个视图
    aug_type1 = 'degree'
    aug_type2 = 'degree' 
    # 使用指定的增强类型生成视图
    feat_data1, edge_index_1 = get_augmented_view(
        edge_indexs,
        feat_data,
        aug_type=aug_type1,  # 使用指定的第一个增强类型
        drop_rate=drop_edge_rate_1
    )
    
    feat_data2, edge_index_2 = get_augmented_view(
        edge_indexs,
        feat_data,
        aug_type=aug_type2,  # 使用指定的第二个增强类型
        drop_rate=drop_edge_rate_2
    )
    
    # 将增强的图结构移到设备上
    for edge_index in edge_index_1:
        edge_index[0] = edge_index[0].to(device)
        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]
    for edge_index in edge_index_2:
        edge_index[0] = edge_index[0].to(device)
        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]
    
    print('training...')
    # 添加新的参数
    if 'mu' not in args:
        args['mu'] = 1.0  # 默认值
    if 'mu_rampup' not in args:
        args['mu_rampup'] = True  # 默认启用rampup
    if 'consistency_rampup' not in args:
        args['consistency_rampup'] = None  # 默认使用总epoch数
    if 'overwrite_viz' not in args:
        args['overwrite_viz'] = True  # 默认覆盖之前的图片

    # 添加RNC相关配置参数，只保留温度参数
    if 'clustering_temperature' not in args:
        args['clustering_temperature'] = 0.8  # 聚类温度参数

    # 添加SDA相关的配置参数
    if 'sda_projections' not in args:
        args['sda_projections'] = 50  # SDA使用的投影数量

    # 记录所有参数到日志文件
    log_and_print("=" * 50)
    log_and_print(f"实验开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log_and_print(f"设备: {device}")
    log_and_print("=" * 50)
    log_and_print("实验参数:")
    for key, value in sorted(args.items()):
        log_and_print(f"  {key}: {value}")
    log_and_print("=" * 50)
    log_and_print(f"数据集: {args['dataset']}")
    log_and_print(f"训练集大小: {len(idx_train)}")
    log_and_print(f"验证集大小: {len(idx_val)}")
    log_and_print(f"测试集大小: {len(idx_test)}")
    log_and_print(f"正样本比例: {sum(y_train)/len(y_train):.4f}")
    log_and_print(f"增强类型: {aug_type1}, {aug_type2}")
    log_and_print(f"边删除率: {drop_edge_rate_1}, {drop_edge_rate_2}")
    log_and_print(f"伪标签平衡: {'启用' if balance_pseudo_labels else '禁用'}")
    log_and_print("=" * 50)

    # 创建保存各种损失的列表，用于绘制损失曲线
    epoch_list = []  # 记录epoch数
    total_loss_list = []  # 总损失
    classification_loss_list = []  # 分类损失
    contrastive_loss_list = []  # 对比损失
    topo_loss_list = []  # 拓扑损失
    sda_loss_list = []  # 球面投影损失
    rnc_loss_list = []  # RNC损失
    clustering_loss_list = []  # 聚类损失

    # 初始化每个epoch的损失累积
    epoch_cls_loss = 0
    epoch_contrastive_loss = 0
    epoch_consistency_loss = 0
    epoch_clustering_loss = 0
    epoch_topo_loss = 0
    epoch_sda_loss = 0  # 新增：SDA损失累积
    epoch_unlabeled_contrast_loss = 0  # 添加此行
    epoch_rnc_loss = 0  # 添加RNC损失累积
    epoch_total_loss = 0
    epoch_pseudo_label_loss = 0
    # 添加伪标签正负样本计数
    epoch_pseudo_pos_count = 0
    epoch_pseudo_neg_count = 0
    num_batches = 0
    pseudo_lpl_loss = 0
    clustering_loss = 0
    # 添加LPL损失的累积变量
    epoch_lpl_loss = 0
    
    # 添加自适应LPL统计信息记录变量
    epoch_adaptive_class0_steps = []
    epoch_adaptive_class1_steps = []
    epoch_adaptive_class0_alpha = []
    epoch_adaptive_class1_alpha = []
    epoch_class_counts = []
    epoch_grad_magnitudes = []

    # 固定使用我们的样本
    sampled_idx_train = idx_train.copy()  # 一个正样本和一个负样本
        
    rd.shuffle(train_unlabeled)
        
    # 正样本和负样本
    pos_samples = [idx for idx in sampled_idx_train if labels[idx] == 1]  # 一个正样本
    neg_samples = [idx for idx in sampled_idx_train if labels[idx] == 0]  # 一个负样本
    # 新增: 创建列表记录每个epoch的伪标签准确率
    all_epoch_pseudo_accuracies = []
    all_epoch_pos_accuracies = []
    all_epoch_neg_accuracies = []
    all_epoch_pseudo_sample_counts = []  # 记录每个epoch的伪标签数量

    for epoch in range(args['num_epochs']):
        # 模型设置为训练模式
        gnn_model_1.train()
        loss = 0
        
        # 获取当前epoch的mu值
        current_mu = get_current_mu(epoch, args)
        
        # 打印epoch开始时的数据统计
        if debug_print:
            print(f"\nEpoch {epoch} 开始:")
            print(f"无标签数据总量: {len(train_unlabeled)}")
            print(f"有标签数据总量: {len(sampled_idx_train)}")
            print(f"其中正样本数量: {len(pos_samples)}")  # 应该是1
            print(f"其中负样本数量: {len(neg_samples)}")  # 应该是1
        
        # 重置epoch统计变量
        epoch_cls_loss = 0
        epoch_contrastive_loss = 0
        epoch_consistency_loss = 0
        epoch_clustering_loss = 0
        epoch_unlabeled_contrast_loss = 0
        epoch_pseudo_label_loss = 0
        epoch_total_loss = 0
        epoch_lpl_loss = 0
        # 重置伪标签正负样本计数
        epoch_pseudo_pos_count = 0
        epoch_pseudo_neg_count = 0
        num_batches = 0

        if debug_print:
            print(f"总batch数量: {num_batches}")

        # 计算总batch数量
        batch_size = args['batch_size']
        num_batches = max(1, (len(train_unlabeled) + batch_size - 1) // batch_size)
        
        if debug_print:
            print(f"总batch数量: {num_batches}")

        for batch in range(num_batches):
            # 获取batch数据：优先使用有标签数据
            batch_nodes = []
            
            # 添加所有的有标签样本(一个正样本和一个负样本)
            batch_nodes.extend(sampled_idx_train)
            
            # 添加无标签数据直到达到batch_size
            remaining_spots = batch_size - len(sampled_idx_train)
            if remaining_spots > 0:
                u_start = batch * remaining_spots
                u_end = min((batch + 1) * remaining_spots, len(train_unlabeled))
                if u_start < len(train_unlabeled):
                    batch_unlabeled = train_unlabeled[u_start:u_end]
                    # 如果不够，循环使用无标签数据
                    if len(batch_unlabeled) < remaining_spots:
                        needed = remaining_spots - len(batch_unlabeled)
                        batch_unlabeled.extend(train_unlabeled[:needed])
                    batch_nodes.extend(batch_unlabeled)
            
            # 分离batch中的有标签和无标签数据
            batch_labeled = [node for node in batch_nodes if node in sampled_idx_train]
            batch_unlabeled = [node for node in batch_nodes if node not in sampled_idx_train]
            
            
            unlabeled_nodes_tensor = torch.tensor(batch_unlabeled, device=device)
            batch_nodes_tensor = torch.tensor(batch_nodes, dtype=torch.long, device=device)
            batch_label = torch.tensor(labels[np.array(batch_labeled)]).long().to(device)
            
            #原始图输出
            original_out, original_h = gnn_model_1(feat_data, edge_indexs)
            
            #两个增强视图输出
            out1, h1 = gnn_model_1(feat_data1, edge_index_1)
            out2, h2 = gnn_model_1(feat_data2, edge_index_2)


            # 只对有标签数据计算分类损失
            labeled_nodes_tensor = torch.tensor(batch_labeled, device=device)
            
            classification_loss_1 = F.nll_loss(out1[labeled_nodes_tensor], batch_label)
            classification_loss_2 = F.nll_loss(out2[labeled_nodes_tensor], batch_label)
        
            # 只对有标签数据生成正负样本对和计算对比损失
            positive_pairs, negative_pairs = generate_contrastive_pairs(batch_labeled, labels, feat_data)
            #输出positive_pairs的形状
            print(f"positive_pairs: {positive_pairs}")
            z_i_1 = h1[torch.tensor([p[0] for p in positive_pairs], device=device)]
            print(f"z_i_1: {z_i_1}")
            z_j_1 = h1[torch.tensor([p[1] for p in positive_pairs], device=device)]
            z_i_2 = h2[torch.tensor([p[0] for p in positive_pairs], device=device)]
            z_j_2 = h2[torch.tensor([p[1] for p in positive_pairs], device=device)]
            
            contrastive_loss_1 = nt_xent_loss(z_i_1, z_j_1)
            contrastive_loss_2 = nt_xent_loss(z_i_2, z_j_2)
            print(f"contrastive_loss_1: {contrastive_loss_1}")
            print(f"contrastive_loss_2: {contrastive_loss_2}")
            # 对所有数据计算组合一致性损失
            consistency_loss = F.mse_loss(h1[batch_nodes_tensor], h2[batch_nodes_tensor])

            if len(batch_nodes) > 0:  
                
                # 对两个增强视图分别进行鲁棒节点聚类
                # 聚类过程不需要跟踪中间计算的梯度，但最终聚类分配需要梯度
                h_orig_unlabeled = original_h[unlabeled_nodes_tensor]
                h1_unlabeled = h1[unlabeled_nodes_tensor]
                h2_unlabeled = h2[unlabeled_nodes_tensor]
                
                # 获取有标签样本的特征，用于固定聚类中心
                labeled_nodes_tensor = torch.tensor(batch_labeled, device=device)
                labeled_features_orig = original_h[labeled_nodes_tensor]
                labeled_classes = torch.tensor(labels[np.array(batch_labeled)], device=device)
                
                # 在前10个epoch中固定聚类中心为已有的正常和欺诈样本
                if epoch < fixed_cluster_epochs and use_clustering_pseudo_labels:
                    
                    # 使用有标签数据初始化原始图的聚类中心，同时处理三个视图
                    cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                        h_orig_unlabeled,  # 原始图特征
                        k=2, 
                        temperature=args["clustering_temperature"],
                        max_iterations=10,
                        labeled_features=labeled_features_orig,  # 传入有标签样本特征
                        labeled_classes=labeled_classes       # 传入有标签样本类别
                    )
             
                elif use_clustering_pseudo_labels:
                    # 在第10个epoch时记录转为自由聚类的信息
                    if epoch == fixed_cluster_epochs and batch == 0:
                        log_and_print(f"\n【Epoch {epoch}】转为自由聚类模式")
                        log_and_print(f"  不再使用有标签样本固定聚类中心")
                    
                    # 10个epoch后，让聚类算法自由寻找更好的聚类中心，同时处理三个视图
                    cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                        h_orig_unlabeled,  # 原始图特征
                        k=2, 
                        temperature=args["clustering_temperature"],
                        max_iterations=10
                    )

           

                # 创建合并的特征和分配
                all_features = torch.cat([h_orig_unlabeled, h1_unlabeled, h2_unlabeled], dim=0)
                all_assignments = torch.cat([cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2], dim=0)
                
                #计算统一的聚类损失
                clustering_loss, num_pos_all, num_neg_all = compute_clustering_loss(
                    all_features, 
                    all_assignments, 
                    centroids_orig  # 使用原始图的聚类中心
                )
                
                # 伪标签生成逻辑：取消置信度筛选，所有无标签样本均参与
                # 聚类结果：多数类为负样本(0)，少数类为正样本(1)
                # 概率对齐：确保用于融合的聚类概率列顺序为 [P(负), P(正)]
                with torch.no_grad():
                    final_pseudo_labels_for_batch_unlabeled = torch.tensor([], dtype=torch.long, device=device)
                    
                    # 确定伪标签来源和计算逻辑
                    if use_original_pseudo_labels and use_clustering_pseudo_labels:
                        # 场景1: 模型输出 + 聚类结果 融合
                        orig_logits_unlabeled = original_out[unlabeled_nodes_tensor]
                        orig_probs_unlabeled = F.softmax(orig_logits_unlabeled, dim=1) # 模型输出概率 [P(负), P(正)]

                        # cluster_assignments_orig 是 [P(属聚类0), P(属聚类1)]
                        # 确定聚类0和聚类1哪个是多数 (负)，哪个是少数 (正)
                        temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1) # 初步判断样本属于哪个聚类
                        count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                        count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()
                        
                        aligned_cluster_probs = cluster_assignments_orig.clone()
                        if count_c0 < count_c1: 
                            # 聚类0是少数类 (应映射为正样本, 标签1)
                            # 聚类1是多数类 (应映射为负样本, 标签0)
                            # 调整列使 aligned_cluster_probs 为 [P(负=聚类1), P(正=聚类0)]
                            aligned_cluster_probs[:, 0] = cluster_assignments_orig[:, 1] # 负样本概率 = 原聚类1概率
                            aligned_cluster_probs[:, 1] = cluster_assignments_orig[:, 0] # 正样本概率 = 原聚类0概率
                        # else: count_c0 >= count_c1
                            # 聚类0是多数类 (负样本, 标签0)
                            # 聚类1是少数类 (正样本, 标签1)
                            # aligned_cluster_probs 无需换列，已经是 [P(负=聚类0), P(正=聚类1)]
                        
                        combined_probs_unlabeled = (orig_probs_unlabeled + aligned_cluster_probs) / 2.0
                        final_pseudo_labels_for_batch_unlabeled = torch.argmax(combined_probs_unlabeled, dim=1)

                    elif use_clustering_pseudo_labels:
                        # 场景2: 仅使用聚类结果
                        temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1) # 初步判断样本属于哪个聚类
                        count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                        count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()

                        if count_c0 >= count_c1: 
                            # 聚类0是多数 (负样本=0), 聚类1是少数 (正样本=1)
                            # 伪标签与 temp_cluster_hard_labels 一致 (聚类0的为0, 聚类1的为1)
                            final_pseudo_labels_for_batch_unlabeled = temp_cluster_hard_labels
                        else: 
                            # 聚类1是多数 (负样本=0), 聚类0是少数 (正样本=1)
                            # 伪标签与 temp_cluster_hard_labels 相反 (聚类0的为1, 聚类1的为0)
                            final_pseudo_labels_for_batch_unlabeled = 1 - temp_cluster_hard_labels
                    
                    elif use_original_pseudo_labels:
                        # 场景3: 仅使用模型输出
                        orig_logits_unlabeled = original_out[unlabeled_nodes_tensor]
                        final_pseudo_labels_for_batch_unlabeled = torch.argmax(orig_logits_unlabeled, dim=1)
                    
                    # 为当前批次中所有无标签样本分配伪标签
                    if final_pseudo_labels_for_batch_unlabeled.numel() > 0:
                        # consistent_high_conf_indices 是相对于 unlabeled_nodes_tensor 的索引
                        consistent_high_conf_indices = torch.arange(final_pseudo_labels_for_batch_unlabeled.size(0), device=device)
                        consistent_pseudo_labels = final_pseudo_labels_for_batch_unlabeled
                        
                        # 累积伪标签正负样本数量
                        epoch_pseudo_pos_count += torch.sum(consistent_pseudo_labels == 1).item()
                        epoch_pseudo_neg_count += torch.sum(consistent_pseudo_labels == 0).item()
                    else:
                        consistent_high_conf_indices = torch.tensor([], dtype=torch.long, device=device)
                        consistent_pseudo_labels = torch.tensor([], dtype=torch.long, device=device) # Ensure consistent_pseudo_labels is defined
                    
                    # Define num_consistent_high_conf based on the number of pseudo-labels generated
                    num_consistent_high_conf = consistent_pseudo_labels.numel()

                
                
                # 添加：为多视图一致的高置信度样本添加焦点损失伪标签训练
                # 计算伪标签损失
                pseudo_label_loss = torch.tensor(0.0, device=device)
                pseudo_lpl_loss = torch.tensor(0.0, device=device)  # 初始化LPL损失

                if num_consistent_high_conf > 0:
                    try:
                        # 获取对应的预测
                        pseudo_logits_1 = out1[unlabeled_nodes_tensor][consistent_high_conf_indices]
                        pseudo_logits_2 = out2[unlabeled_nodes_tensor][consistent_high_conf_indices]
                        
                        # 检查维度是否匹配
                        if pseudo_logits_1.size(0) == consistent_pseudo_labels.size(0) and \
                           pseudo_logits_2.size(0) == consistent_pseudo_labels.size(0):
                            
                            # 计算高置信度样本中的正负样本比例以动态调整alpha
                            pos_samples = torch.sum(consistent_pseudo_labels == 1).item()
                            total_samples = consistent_pseudo_labels.size(0)
                            
                            # 确保total_samples不为零
                            if total_samples > 0:
                                pos_ratio = pos_samples / total_samples
                                dynamic_alpha = max(0.25, min(0.95, 1.0 - pos_ratio))
                                max_gamma = 5.0
                                min_gamma = 2.0
                                max_epoch = 100
                                current_gamma = min_gamma + (max_gamma - min_gamma) * min(epoch / max_epoch, 1.0)
                                
                                # 动态调整gamma参数
                                max_gamma_focal = 5.0    # Focal Loss聚焦参数最大值
                                min_gamma_focal = 2.0    # Focal Loss聚焦参数最小值
                                max_gamma_ga = 0.9       # 梯度感知平衡参数最大值
                                min_gamma_ga = 0.5       # 梯度感知平衡参数最小值
                                max_epoch = 150
                                
                                # 随着训练进行，增加gamma_focal使模型更关注难样本
                                current_gamma_focal = min_gamma + (max_gamma - min_gamma) * min(epoch / max_epoch, 1.0)
                                
                                # 随着训练进行，增加gamma_ga使模型对类别平衡更敏感
                                current_gamma_ga = min_gamma_ga + (max_gamma_ga - min_gamma_ga) * min(epoch / max_epoch, 1.0)
                                
                                # 根据伪标签策略选择不同的损失函数
                                if epoch < -1:
                                    # 仅使用原始图输出时，使用普通的交叉熵损失
                                    pseudo_label_loss_1 = F.cross_entropy(pseudo_logits_1, consistent_pseudo_labels)
                                    pseudo_label_loss_2 = F.cross_entropy(pseudo_logits_2, consistent_pseudo_labels)
                                else:
                                    #使用GradientAwareFocalLoss

                                    # 视图1
                                    pseudo_label_loss_1 = gradient_aware_focal(
                                        pseudo_logits_1, 
                                        consistent_pseudo_labels
                                    )
                                    
                                    # 视图2
                                    pseudo_label_loss_2 = gradient_aware_focal(
                                        pseudo_logits_2, 
                                        consistent_pseudo_labels
                                    )
                                
                                # 组合两个视图的损失
                                pseudo_label_loss = (pseudo_label_loss_1 + pseudo_label_loss_2) / 2
                                
                                # 计算类别分布
                                class_counts = [torch.sum(consistent_pseudo_labels == 0).item(), 
                                              torch.sum(consistent_pseudo_labels == 1).item()]
                                
                                # 确保类别分布中不包含零
                                if min(class_counts) > 0:
                                    
                                    # 计算类别频率用于逻辑调整
                                    total_samples = sum(class_counts)
                                    
                                    # 使用自适应LPL损失 - 替换原来的基础LPL损失计算
                                    # 视图1的自适应LPL损失
                                    adap_lpl_loss_1, _, _, steps_1, alphas_1 = adaptive_lpl_loss(pseudo_logits_1, None, consistent_pseudo_labels, is_logits=True)
                                    
                                    # 视图2的自适应LPL损失
                                    adap_lpl_loss_2, _, _, steps_2, alphas_2 = adaptive_lpl_loss(pseudo_logits_2, None, consistent_pseudo_labels, is_logits=True)
                                    
                                    # 组合自适应LPL损失
                                    pseudo_lpl_loss = (adap_lpl_loss_1 + adap_lpl_loss_2) / 2
                                    
                                    # 收集统计信息
                                    if epoch % 5 == 0 and batch == 0:  # 每5个epoch记录一次
                                        # 记录每个类别的平均步数和扰动强度
                                        for c in range(2):
                                            cls_mask_1 = consistent_pseudo_labels == c
                                            if torch.any(cls_mask_1):
                                                avg_steps = steps_1[cls_mask_1].float().mean().item()
                                                avg_alpha = alphas_1[cls_mask_1].mean().item()
                                                
                                                if c == 0:
                                                    epoch_adaptive_class0_steps.append(avg_steps)
                                                    epoch_adaptive_class0_alpha.append(avg_alpha)
                                                else:
                                                    epoch_adaptive_class1_steps.append(avg_steps)
                                                    epoch_adaptive_class1_alpha.append(avg_alpha)
                                        
                                        # 记录类别计数和梯度幅度
                                        epoch_class_counts.append(adaptive_lpl_loss.class_counts.clone())
                                        epoch_grad_magnitudes.append(adaptive_lpl_loss.class_grad_mags.clone())
                                        
                                        # 记录统计信息
                                        log_and_print(f"\n【自适应LPL参数 (Epoch {epoch}, Batch 0)】")
                                        log_and_print(f"  类别0 平均步数: {epoch_adaptive_class0_steps[-1]:.2f}")
                                        log_and_print(f"  类别1 平均步数: {epoch_adaptive_class1_steps[-1]:.2f}")
                                        log_and_print(f"  类别0 平均扰动强度: {epoch_adaptive_class0_alpha[-1]:.4f}")
                                        log_and_print(f"  类别1 平均扰动强度: {epoch_adaptive_class1_alpha[-1]:.4f}")
                                        
                                        # 显示类别统计和梯度信息
                                        class_counts = adaptive_lpl_loss.class_counts.cpu().numpy()
                                        grad_mags = adaptive_lpl_loss.class_grad_mags.cpu().numpy()
                                        
                                        log_and_print(f"  类别计数: [{class_counts[0]:.1f}, {class_counts[1]:.1f}]")
                                        log_and_print(f"  梯度幅度: [{grad_mags[0]:.4f}, {grad_mags[1]:.4f}]")
                        

                            else:
                                log_and_print(f"警告: 没有有效的高置信度样本，跳过伪标签损失计算")
                    except Exception as e:
                        log_and_print(f"警告: 计算伪标签损失时出错: {e}")
                        pseudo_label_loss = torch.tensor(0.0, device=device)
                        pseudo_lpl_loss = torch.tensor(0.0, device=device)

            else:
                # 如果没有无标签节点，则设置所有相关变量为0
                clustering_loss = torch.tensor(0.0, device=device)
                unlabeled_contrast_loss = torch.tensor(0.0, device=device)
                unlabeled_contrast_loss_1 = torch.tensor(0.0, device=device)
                unlabeled_contrast_loss_2 = torch.tensor(0.0, device=device)
                pseudo_label_loss = torch.tensor(0.0, device=device)

            # 设置topo_loss为0
            topo_loss = 0.0

            # 设置sda_loss为0
            sda_loss = torch.tensor(0.0, device=device)
            
            # 计算总损失 - 整合所有损失
            total_loss = (classification_loss_1 + classification_loss_2) / 2 + \
                        (contrastive_loss_1 + contrastive_loss_2) / 2 + \
                        current_mu * consistency_loss + \
                        current_mu * pseudo_label_loss + \
                        current_mu * pseudo_lpl_loss + \
                        current_mu * clustering_loss
                        # \ # current_mu * pseudo_lpl_loss 
                        # current_mu * pseudo_label_loss

            # 计算单独的损失值用于记录
            cls_loss = (classification_loss_1 + classification_loss_2) / 2
            contrastive_loss = (contrastive_loss_1 + contrastive_loss_2) / 2
            
            # 累积各种损失
            epoch_cls_loss += cls_loss.item()
            epoch_contrastive_loss += contrastive_loss.item()
            epoch_consistency_loss += consistency_loss.item()
            epoch_clustering_loss += clustering_loss.item() if isinstance(clustering_loss, torch.Tensor) else clustering_loss
            # epoch_topo_loss += topo_loss if isinstance(topo_loss, float) else topo_loss.item()
            # epoch_sda_loss += sda_loss.item() if isinstance(sda_loss, torch.Tensor) else sda_loss  # 添加SDA损失累积
            epoch_unlabeled_contrast_loss += unlabeled_contrast_loss.item()  # 添加此行
            epoch_rnc_loss += 0  # 添加RNC损失累积
            epoch_pseudo_label_loss += pseudo_label_loss.item() if isinstance(pseudo_label_loss, torch.Tensor) else pseudo_label_loss  # 修正伪标签损失累积
            epoch_total_loss += total_loss.item()
            # 添加LPL损失的累积
            epoch_lpl_loss += pseudo_lpl_loss.item() if isinstance(pseudo_lpl_loss, torch.Tensor) else pseudo_lpl_loss
       
            
            num_batches += 1
            
            optimizer_1.zero_grad()
            total_loss.backward()
            optimizer_1.step()
            
            loss += total_loss.item()

        # 在epoch结束时计算平均损失
        epoch_cls_loss /= num_batches
        epoch_contrastive_loss /= num_batches
        epoch_consistency_loss /= num_batches
        epoch_clustering_loss /= num_batches
        epoch_unlabeled_contrast_loss /= num_batches
        epoch_pseudo_label_loss /= num_batches
        epoch_total_loss /= num_batches
        # 添加LPL损失的平均值计算
        epoch_lpl_loss /= num_batches

        # 更新日志记录，添加GADice和GACE损失相关信息
        log_message = (f"Epoch {epoch}: 总损失={epoch_total_loss:.4f}, 分类损失={epoch_cls_loss:.4f}, "
                      f"对比损失={epoch_contrastive_loss:.4f}, 聚类损失={epoch_clustering_loss:.4f}, ")
        
        # 添加其余损失信息
        log_message += (f"一致性损失={current_mu * epoch_consistency_loss:.4f}, "
                      f"无标签对比损失={current_mu * epoch_unlabeled_contrast_loss:.4f}, "
                      f"mu={current_mu:.4f}")
        
        log_and_print(log_message)
        
        # 输出损失信息 - 更新为包含SDA损失
        log_and_print(f'\nEpoch {epoch} Loss Summary:')
        log_and_print(f'  分类损失: {epoch_cls_loss:.4f}')
        log_and_print(f'  对比损失: {epoch_contrastive_loss:.4f}')
        log_and_print(f'  聚类损失: {epoch_clustering_loss:.4f}')
        log_and_print(f'  一致性损失: {epoch_consistency_loss:.4f}')
        log_and_print(f'  无标签对比损失: {epoch_unlabeled_contrast_loss:.4f}')
        # 确保epoch_pseudo_label_loss已定义
        if 'epoch_pseudo_label_loss' not in locals():
            epoch_pseudo_label_loss = 0.0
        log_and_print(f'  伪标签损失: {epoch_pseudo_label_loss / max(1, num_batches):.4f}')  # 添加此行
        log_and_print(f'  LPL损失: {epoch_lpl_loss / max(1, num_batches):.4f}')
        log_and_print(f'  总损失: {epoch_total_loss:.4f}')
        
        # 添加自适应LPL损失统计信息
        if epoch % 5 == 0 and len(epoch_adaptive_class0_steps) > 0 and len(epoch_adaptive_class1_steps) > 0:
            # 计算平均步数和扰动强度
            avg_class0_steps = sum(epoch_adaptive_class0_steps) / max(len(epoch_adaptive_class0_steps), 1)
            avg_class1_steps = sum(epoch_adaptive_class1_steps) / max(len(epoch_adaptive_class1_steps), 1)
            avg_class0_alpha = sum(epoch_adaptive_class0_alpha) / max(len(epoch_adaptive_class0_alpha), 1)
            avg_class1_alpha = sum(epoch_adaptive_class1_alpha) / max(len(epoch_adaptive_class1_alpha), 1)
            
            # 输出自适应LPL损失统计信息
            log_and_print(f'\n【自适应LPL损失统计信息】')
            log_and_print(f'  类别0(多数类) 平均扰动步数: {avg_class0_steps:.2f}')
            log_and_print(f'  类别1(少数类) 平均扰动步数: {avg_class1_steps:.2f}')
            log_and_print(f'  类别0(多数类) 平均扰动强度: {avg_class0_alpha:.4f}')
            log_and_print(f'  类别1(少数类) 平均扰动强度: {avg_class1_alpha:.4f}')
            
            # 输出类别统计和梯度信息
            if len(epoch_class_counts) > 0 and len(epoch_grad_magnitudes) > 0:
                avg_class_counts = torch.stack(epoch_class_counts).mean(dim=0)
                avg_grad_mags = torch.stack(epoch_grad_magnitudes).mean(dim=0)
                
                # 归一化处理
                normalized_counts = avg_class_counts / (avg_class_counts.sum() + 1e-8)
                normalized_grads = F.softmax(avg_grad_mags, dim=0)
                
                log_and_print(f'  类别频率分布: [{normalized_counts[0].item():.3f}, {normalized_counts[1].item():.3f}]')
                log_and_print(f'  类别梯度分布: [{normalized_grads[0].item():.3f}, {normalized_grads[1].item():.3f}]')
        
        # 重置自适应LPL统计信息记录变量
        epoch_adaptive_class0_steps = []
        epoch_adaptive_class1_steps = []
        epoch_adaptive_class0_alpha = []
        epoch_adaptive_class1_alpha = []
        epoch_class_counts = []
        epoch_grad_magnitudes = []

        # 记录各种损失
        epoch_list.append(epoch)
        classification_loss_list.append(epoch_cls_loss)
        contrastive_loss_list.append(epoch_contrastive_loss)
        topo_loss_list.append(0)  # 暂不使用topo_loss
        sda_loss_list.append(0)  # 暂不使用sda_loss
        rnc_loss_list.append(epoch_rnc_loss)
        clustering_loss_list.append(epoch_clustering_loss)
        total_loss_list.append(epoch_total_loss)
        

        # 验证和测试
        if epoch % 1 == 0:  # 每个epoch都验证
            val_auc, val_ap, val_f1, val_g_mean, val_acc_label0, val_acc_label1, val_acc_overall = test(idx_val, y_val, gnn_model_1, feat_data, edge_indexs)
            log_and_print(f'Epoch: {epoch}, 验证集 AUC: {val_auc:.4f}, AP: {val_ap:.4f}, F1: {val_f1:.4f}, G-mean: {val_g_mean:.4f}, Label 0 ACC: {val_acc_label0:.4f}, Label 1 ACC: {val_acc_label1:.4f}, Overall ACC: {val_acc_overall:.4f}')

            # 在第100轮后调整学习率和weight_decay
            # if epoch == 100:
            #     # 降低学习率从0.002到0.001
            #     for param_group in optimizer_1.param_groups:
            #         param_group['lr'] = 0.001
            #         param_group['weight_decay'] = 5e-5  # 增大weight_decay从3e-5到5e-5
            #     log_and_print(f"第{epoch}轮后调整参数: 学习率降为0.001, weight_decay增大为5e-5")
            
            # 在这里调用scheduler.step，传入验证指标
            # 在scheduler.step之后添加以下代码
            # if epoch > 20:
            #     # 记录更新前的学习率和weight_decay
            #     prev_lr = optimizer_1.param_groups[0]['lr']
            #     prev_wd = optimizer_1.param_groups[0]['weight_decay']
                
            #     # 更新学习率
            #     scheduler.step(val_auc)
                
            #     # 获取更新后的当前学习率
            #     current_lr = optimizer_1.param_groups[0]['lr']
                
            #     # 学习率确实变化了才调整weight_decay
            #     if current_lr != prev_lr:
            #         # 保持weight_decay与学习率的比例
            #         initial_lr = 0.003
            #         initial_wd = 3e-5
            #         wd_to_lr_ratio = initial_wd / initial_lr
            #         new_wd = current_lr * wd_to_lr_ratio
                    
            #         # 更新weight_decay
            #         for param_group in optimizer_1.param_groups:
            #             param_group['weight_decay'] = new_wd
                    
            #         log_and_print(f"学习率调整: {prev_lr:.2e} -> {current_lr:.2e}")
            #         log_and_print(f"权重衰减调整: {prev_wd:.2e} -> {new_wd:.2e}")
            # # 打印当前学习率和weight_decay
            # current_lr = optimizer_1.param_groups[0]['lr']
            # current_wd = optimizer_1.param_groups[0]['weight_decay']
            # log_and_print(f"当前学习率: {current_lr:.2e}, weight_decay: {current_wd:.2e}")

            # 保存最佳模型
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_model_state = gnn_model_1.state_dict()

        # 在最后一个epoch输出最终验证结果
        if epoch == args['num_epochs'] - 1:
            log_and_print(f'最终验证集 AUC: {best_val_auc:.4f}')

        # 加载最佳模型并进行测试
        gnn_model_1.load_state_dict(best_model_state)  
        test_auc, test_ap, test_f1, test_g_mean, test_acc_label0, test_acc_label1, test_acc_overall = test(idx_test, y_test, gnn_model_1, feat_data, edge_indexs)
        
        # 输出每个epoch的伪标签正负样本统计
        if epoch_pseudo_pos_count > 0 or epoch_pseudo_neg_count > 0:
            avg_pos_samples = epoch_pseudo_pos_count / max(num_batches, 1)
            avg_neg_samples = epoch_pseudo_neg_count / max(num_batches, 1)
            total_avg = avg_pos_samples + avg_neg_samples
            pos_ratio = avg_pos_samples / total_avg * 100 if total_avg > 0 else 0
            neg_ratio = avg_neg_samples / total_avg * 100 if total_avg > 0 else 0
            
            log_and_print(f"\n【Epoch {epoch} 伪标签统计】")
            log_and_print(f"  平均每个batch伪标签数量: {total_avg:.2f}")
            log_and_print(f"  平均正样本: {avg_pos_samples:.2f} ({pos_ratio:.1f}%)")
            log_and_print(f"  平均负样本: {avg_neg_samples:.2f} ({neg_ratio:.1f}%)")
            log_and_print(f"  正负样本比: {avg_pos_samples:.2f}:{avg_neg_samples:.2f}")
        else:
            log_and_print(f"\n【Epoch {epoch} 伪标签统计】无伪标签生成")
        
        # 更新最佳测试AUC
        if test_auc > best_test_auc:
            best_test_auc = test_auc
        # 在验证时对无标签数据进行聚类可视化，使用RNC聚类而非K-means
        with torch.no_grad():
            # 获取测试数据的嵌入
            test_tensor = torch.tensor(idx_test, device=device)
            _, h_test = gnn_model_1(feat_data, edge_indexs)
            h_test = h_test[test_tensor].clone().detach()
            
            # 对测试数据的嵌入进行t-SNE降维
            tsne = TSNE(n_components=3, random_state=42)
            h_test_tsne = tsne.fit_transform(h_test.cpu().numpy())
            
            # 创建3D图形
            plt.figure(figsize=(12, 10))
            ax = plt.axes(projection='3d')
            
            # 获取测试集的真实标签
            y_test_np = np.array(y_test)
            
            # 使用不同颜色表示正负样本（橙色表示负样本，蓝色表示正样本）
            colors = ['#FF7F0E', '#1F77B4']  # 橙色和蓝色
            markers = ['o', '^']  # 圆形表示负样本，三角形表示正样本
            
            # 分别绘制正负样本
            for label in [0, 1]:
                mask = y_test_np == label
                ax.scatter(
                    h_test_tsne[mask, 0], 
                    h_test_tsne[mask, 1], 
                    h_test_tsne[mask, 2],
                    c=colors[label],
                    marker=markers[label],
                    s=50 if label == 1 else 30,  # 正样本点稍大
                    alpha=0.7, 
                    label=f"{'Positive' if label == 1 else 'Negative'} samples"
                )
            
            # 添加图例和标题
            ax.legend(loc='upper right', fontsize=12)
            ax.set_title(f'Test Set Embedding Visualization (Epoch {epoch})\n'
                         f'Total Samples: {len(y_test)} | '
                         f'Positive: {np.sum(y_test_np == 1)} | '
                         f'Negative: {np.sum(y_test_np == 0)}', 
                         fontsize=14)
            
            # 设置轴标签
            ax.set_xlabel('t-SNE1', fontsize=12)
            ax.set_ylabel('t-SNE2', fontsize=12)
            ax.set_zlabel('t-SNE3', fontsize=12)
            
            # 设置最佳视角
            ax.view_init(elev=25, azim=45)
            
            # 保存图像
            embed_viz_dir = '/root/autodl-tmp/hali/antifraud/log_zp4/fig'
            os.makedirs(embed_viz_dir, exist_ok=True)
            
            # 改为使用单一文件名，覆盖之前的图像
            embed_viz_file = os.path.join(embed_viz_dir, f'test_embed_viz_latest.png')
            plt.tight_layout()
            plt.savefig(embed_viz_file, dpi=300, bbox_inches='tight', facecolor='white')
            plt.close()
        log_and_print(f'测试结果: AUC={test_auc:.4f}, AP={test_ap:.4f}, F1={test_f1:.4f}, G-mean={test_g_mean:.4f}, 最佳测试 AUC: {best_test_auc:.4f}, Label 0 ACC: {test_acc_label0:.4f}, Label 1 ACC: {test_acc_label1:.4f}, Overall ACC: {test_acc_overall:.4f}')    



    # 添加实验总结信息
    log_and_print("\n" + "=" * 50)
    log_and_print(f"实验结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log_and_print(f"最佳验证集AUC: {best_val_auc:.4f}")
    log_and_print(f"最佳测试集AUC: {best_test_auc:.4f}")
    log_and_print(f"最佳测试集ACC: {test_acc_overall:.4f}")
    log_and_print(f"最佳测试集负样本(标签0)ACC: {test_acc_label0:.4f}")
    log_and_print(f"最佳测试集正样本(标签1)ACC: {test_acc_label1:.4f}")
    log_and_print(f"最终测试集结果: AUC={test_auc:.4f}, AP={test_ap:.4f}, F1={test_f1:.4f}, G-mean={test_g_mean:.4f}, Label 0 ACC: {test_acc_label0:.4f}, Label 1 ACC: {test_acc_label1:.4f}, Overall ACC: {test_acc_overall:.4f}")
    
    # 新增: 增加伪标签质量评估记录
    if len(all_epoch_pseudo_accuracies) > 0:
        # 找出最佳epoch及其索引
        best_acc_idx = all_epoch_pseudo_accuracies.index(max(all_epoch_pseudo_accuracies))
        best_pos_acc_idx = all_epoch_pos_accuracies.index(max(all_epoch_pos_accuracies))
        best_neg_acc_idx = all_epoch_neg_accuracies.index(max(all_epoch_neg_accuracies))
        
        # 计算最多伪标签的epoch
        max_samples_idx = all_epoch_pseudo_sample_counts.index(max(all_epoch_pseudo_sample_counts))
        max_samples = all_epoch_pseudo_sample_counts[max_samples_idx]
        
        log_and_print("\n伪标签质量评估总结:")
        log_and_print(f"  最高伪标签准确率: {max(all_epoch_pseudo_accuracies):.4f} (Epoch {best_acc_idx})")
        log_and_print(f"  最终伪标签准确率: {all_epoch_pseudo_accuracies[-1]:.4f}")
        log_and_print(f"  平均伪标签准确率: {sum(all_epoch_pseudo_accuracies)/len(all_epoch_pseudo_accuracies):.4f}")
        
        # 输出正负样本伪标签准确率统计
        log_and_print(f"  正样本最高伪标签准确率: {max(all_epoch_pos_accuracies):.4f} (Epoch {best_pos_acc_idx})")
        log_and_print(f"  负样本最高伪标签准确率: {max(all_epoch_neg_accuracies):.4f} (Epoch {best_neg_acc_idx})")
        log_and_print(f"  最多伪标签样本数: {max_samples} (Epoch {max_samples_idx})")
        
        # 输出每个epoch的准确率变化
        log_and_print("\n伪标签准确率变化:")
        for i, (acc, pos_acc, neg_acc, sample_count) in enumerate(zip(
            all_epoch_pseudo_accuracies, all_epoch_pos_accuracies, 
            all_epoch_neg_accuracies, all_epoch_pseudo_sample_counts)):
            log_and_print(f"  Epoch {i}: 总体={acc:.4f}, 正样本={pos_acc:.4f}, 负样本={neg_acc:.4f}, 样本数={sample_count}")
            
    log_and_print("=" * 50)
    log_and_print(f"日志文件保存在: {log_file_path}")
    log_and_print("=" * 50)

    # 生成最终的嵌入可视化
    out, embedding = gnn_model_1(feat_data, edge_indexs)
    print('生成嵌入可视化...')
    Visualization(labels, embedding.cpu().detach(), prefix)

    
    
    log_and_print(f"当前学习率: {optimizer_1.param_groups[0]['lr']}")

    # 添加：输出多视图一致高置信度样本信息
    if 'epoch_consistent_high_conf_samples' in locals() and 'epoch_total_unlabeled_samples' in locals():
        log_and_print(f"  多视图一致高置信度样本数: {epoch_consistent_high_conf_samples}")
        avg_consistent_high_conf_percent = epoch_consistent_high_conf_samples / max(1, epoch_total_unlabeled_samples/2) * 100
        log_and_print(f"  多视图一致高置信度样本占比: {avg_consistent_high_conf_percent:.2f}%")
    else:
        log_and_print(f"  未记录多视图一致高置信度样本统计信息")


